#ifndef EMP_F2K_H__
#define EMP_F2K_H__
#include "emp-tool/utils/block.h"

namespace emp {
/* carry-less multiplication in GF(2^128) without modular reduction */
#ifdef __x86_64__
__attribute__((target("sse2,pclmul")))
inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2) {
	__m128i tmp3, tmp4, tmp5, tmp6;
	tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
	tmp4 = _mm_clmulepi64_si128(a, b, 0x10);
	tmp5 = _mm_clmulepi64_si128(a, b, 0x01);
	tmp6 = _mm_clmulepi64_si128(a, b, 0x11);

	tmp4 = _mm_xor_si128(tmp4, tmp5);
	tmp5 = _mm_slli_si128(tmp4, 8);
	tmp4 = _mm_srli_si128(tmp4, 8);
	tmp3 = _mm_xor_si128(tmp3, tmp5);
	tmp6 = _mm_xor_si128(tmp6, tmp4);
	// unreduced 256-bit product: low half in tmp3, high half in tmp6
	*res1 = tmp3;
	*res2 = tmp6;
}
#elif defined(__aarch64__)
inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2) {
	__m128i tmp3, tmp4, tmp5, tmp6;
	poly64_t a_lo = (poly64_t)vget_low_u64(vreinterpretq_u64_m128i(a));
	poly64_t a_hi = (poly64_t)vget_high_u64(vreinterpretq_u64_m128i(a));
	poly64_t b_lo = (poly64_t)vget_low_u64(vreinterpretq_u64_m128i(b));
	poly64_t b_hi = (poly64_t)vget_high_u64(vreinterpretq_u64_m128i(b));
	tmp3 = (__m128i)vmull_p64(a_lo, b_lo);
	tmp4 = (__m128i)vmull_p64(a_hi, b_lo);
	tmp5 = (__m128i)vmull_p64(a_lo, b_hi);
	tmp6 = (__m128i)vmull_p64(a_hi, b_hi);

	tmp4 = _mm_xor_si128(tmp4, tmp5);
	tmp5 = _mm_slli_si128(tmp4, 8);
	tmp4 = _mm_srli_si128(tmp4, 8);
	tmp3 = _mm_xor_si128(tmp3, tmp5);
	tmp6 = _mm_xor_si128(tmp6, tmp4);
	// unreduced 256-bit product: low half in tmp3, high half in tmp6
	*res1 = tmp3;
	*res2 = tmp6;
}
#endif
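
/* The four carry-less products above implement 128x128 -> 256-bit schoolbook
 * multiplication: writing a = a_hi*X^64 + a_lo and b = b_hi*X^64 + b_lo,
 *   a * b = (a_hi*b_hi)*X^128 + (a_hi*b_lo ^ a_lo*b_hi)*X^64 + a_lo*b_lo,
 * and the shift/XOR epilogue folds the middle term into the low (*res1) and
 * high (*res2) 128-bit halves of the product. */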

/* Galois field reduction with bit-reflected input/output */
#ifdef __x86_64__
__attribute__((target("sse2")))
#endif
//https://www.intel.com/content/dam/develop/public/us/en/documents/carry-less-multiplication-instruction.pdf figure 5
inline block reduce_reflect(__m128i tmp3, __m128i tmp6) { // tmp3 holds the low 128 bits, tmp6 the high 128 bits
	__m128i tmp2, tmp4, tmp5, tmp7, tmp8, tmp9;
	tmp7 = _mm_srli_epi32(tmp3, 31);
	tmp8 = _mm_srli_epi32(tmp6, 31);
	tmp3 = _mm_slli_epi32(tmp3, 1);
	tmp6 = _mm_slli_epi32(tmp6, 1);

	tmp9 = _mm_srli_si128(tmp7, 12);
	tmp8 = _mm_slli_si128(tmp8, 4);
	tmp7 = _mm_slli_si128(tmp7, 4);
	tmp3 = _mm_or_si128(tmp3, tmp7);
	tmp6 = _mm_or_si128(tmp6, tmp8);
	tmp6 = _mm_or_si128(tmp6, tmp9);

	tmp7 = _mm_slli_epi32(tmp3, 31);
	tmp8 = _mm_slli_epi32(tmp3, 30);
	tmp9 = _mm_slli_epi32(tmp3, 25);
	tmp7 = _mm_xor_si128(tmp7, tmp8);
	tmp7 = _mm_xor_si128(tmp7, tmp9);
	tmp8 = _mm_srli_si128(tmp7, 4);
	tmp7 = _mm_slli_si128(tmp7, 12);
	tmp3 = _mm_xor_si128(tmp3, tmp7);

	tmp2 = _mm_srli_epi32(tmp3, 1);
	tmp4 = _mm_srli_epi32(tmp3, 2);
	tmp5 = _mm_srli_epi32(tmp3, 7);
	tmp2 = _mm_xor_si128(tmp2, tmp4);
	tmp2 = _mm_xor_si128(tmp2, tmp5);
	tmp2 = _mm_xor_si128(tmp2, tmp8);
	tmp3 = _mm_xor_si128(tmp3, tmp2);
	return _mm_xor_si128(tmp6, tmp3);
}
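
/* This is the reduction from Figure 5 of the Intel CLMUL white paper linked
 * above: it reduces the 256-bit product modulo p(X) = X^128 + X^7 + X^2 + X + 1
 * with inputs and outputs in bit-reflected order, as used by AES-GCM's GHASH.
 * The initial shift-left-by-1 compensates for the reflection, and the shift
 * amounts (31/30/25 and 1/2/7) correspond to the X^7 + X^2 + X terms of p. */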

/* Galois field reduction without reflection */
#ifdef __x86_64__
__attribute__((target("sse2")))
#endif
//https://www.intel.com/content/dam/develop/public/us/en/documents/carry-less-multiplication-instruction.pdf figure 7
inline block reduce(__m128i tmp3, __m128i tmp6) { // tmp3 holds the low 128 bits, tmp6 the high 128 bits
	__m128i tmp7, tmp8, tmp9, tmp10, tmp11, tmp12;
	__m128i XMMMASK = _mm_setr_epi32(0xffffffff, 0x0, 0x0, 0x0);
	tmp7 = _mm_srli_epi32(tmp6, 31); 
	tmp8 = _mm_srli_epi32(tmp6, 30); 
	tmp9 = _mm_srli_epi32(tmp6, 25);

	tmp7 = _mm_xor_si128(tmp7, tmp8); 
	tmp7 = _mm_xor_si128(tmp7, tmp9);

	tmp8 = _mm_shuffle_epi32(tmp7, 147); // rotate the four 32-bit words up by one position

	tmp7 = _mm_and_si128(XMMMASK, tmp8);
	tmp8 = _mm_andnot_si128(XMMMASK, tmp8);
	tmp3 = _mm_xor_si128(tmp3, tmp8);
	tmp6 = _mm_xor_si128(tmp6, tmp7);

	tmp10 = _mm_slli_epi32(tmp6, 1);
	tmp3 = _mm_xor_si128(tmp3, tmp10);
	tmp11 = _mm_slli_epi32(tmp6, 2);
	tmp3 = _mm_xor_si128(tmp3, tmp11); 
	tmp12 = _mm_slli_epi32(tmp6, 7);
	tmp3 = _mm_xor_si128(tmp3, tmp12);
	return _mm_xor_si128(tmp3, tmp6);
}
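
/* This is the reduction from Figure 7 of the same white paper: same modulus
 * p(X) = X^128 + X^7 + X^2 + X + 1, but in natural (non-reflected) bit order.
 * A quick sanity check, assuming makeBlock(high, low) from block.h: the high
 * half set to 1 represents X^128, and indeed reduce(zero_block, makeBlock(0, 1))
 * yields makeBlock(0, 0x87), i.e. X^7 + X^2 + X + 1 (0x87 = 0b10000111). */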


inline void gfmul(__m128i a, __m128i b, __m128i *res) {
	block r1, r2;
	mul128(a, b, &r1, &r2);
	*res = reduce(r1, r2);
}

inline void gfmul_reflect(__m128i a, __m128i b, __m128i *res) {
	block r1, r2;
	mul128(a, b, &r1, &r2);
	*res = reduce_reflect(r1, r2);
}
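
/* Usage sketch (illustrative only; makeBlock(high, low) comes from block.h):
 *
 *   block a = makeBlock(0, 2);           // the field element X
 *   block b = makeBlock(1ULL << 63, 0);  // X^127
 *   block c;
 *   gfmul(a, b, &c);                     // c = X^128 mod p = makeBlock(0, 0x87)
 */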


/* inner product of two Galois field vectors with reduction */
inline void vector_inn_prdt_sum_red(block *res, const block *a, const block *b, int sz) {
	block r = zero_block;
	block r1;
	for(int i = 0; i < sz; i++) {
		gfmul(a[i], b[i], &r1);
		r = r ^ r1;
	}
	*res = r;
}

/* inner product of two Galois field vectors with reduction */
template<int N>
inline void vector_inn_prdt_sum_red(block *res, const block *a, const block *b) {
	vector_inn_prdt_sum_red(res, a, b, N);
}
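
/* Usage sketch (illustrative):
 *
 *   block a[4], b[4], out;
 *   // ... fill a and b ...
 *   vector_inn_prdt_sum_red<4>(&out, a, b);  // out = sum_i a[i] * b[i] in GF(2^128)
 */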

/* inner product of two Galois field vectors without reduction */
inline void vector_inn_prdt_sum_no_red(block *res, const block *a, const block *b, int sz) {
	block r1 = zero_block, r2 = zero_block;
	block r11, r12;
	for(int i = 0; i < sz; i++) {
		mul128(a[i], b[i], &r11, &r12);
		r1 = r1 ^ r11;
		r2 = r2 ^ r12;
	}
	res[0] = r1;
	res[1] = r2;
}

/* inner product of two Galois field vectors without reduction */
template<int N>
inline void vector_inn_prdt_sum_no_red(block *res, const block *a, const block *b) {
	vector_inn_prdt_sum_no_red(res, a, b, N);
}
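
/* Note: res must point to two blocks; res[0] receives the low 128 bits and
 * res[1] the high 128 bits of the accumulated 256-bit sum. Deferring the
 * modular reduction this way costs a single reduce(res[0], res[1]) at the
 * end instead of one reduction per multiplication. */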

/* coefficients of an almost-universal hash function */
inline void uni_hash_coeff_gen(block* coeff, block seed, int sz) {
	// Handle small values of sz directly (coeff[i] = seed^{i+1})
	coeff[0] = seed;
	if(sz == 1) return;

	gfmul(seed, seed, &coeff[1]);
	if(sz == 2) return;

	gfmul(coeff[1], seed, &coeff[2]);
	if(sz == 3) return;

	block multiplier;
	gfmul(coeff[2], seed, &multiplier);
	coeff[3] = multiplier;
	if(sz == 4) return;

	// Compute the rest in batches of 4, multiplying by multiplier = seed^4
	int i = 4;
	for(; i < sz - 3; i += 4) {
		gfmul(coeff[i - 4], multiplier, &coeff[i]);
		gfmul(coeff[i - 3], multiplier, &coeff[i + 1]);
		gfmul(coeff[i - 2], multiplier, &coeff[i + 2]);
		gfmul(coeff[i - 1], multiplier, &coeff[i + 3]);
	}

	// Handle the remaining sz % 4 elements one at a time
	int remainder = sz % 4;
	if(remainder != 0) {
		i = sz - remainder;
		for(; i < sz; ++i)
			gfmul(coeff[i - 1], seed, &coeff[i]);
	}
}

/* coefficients of an almost-universal hash function */
template<int N>
inline void uni_hash_coeff_gen(block* coeff, block seed) {
	uni_hash_coeff_gen(coeff, seed, N);	
}
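
/* Usage sketch (illustrative): coeff[i] = seed^{i+1}, so a polynomial-style
 * almost-universal hash of a message m[0..n-1] is one coefficient generation
 * plus one inner product (seed and m assumed already defined):
 *
 *   block coeff[64], digest;
 *   uni_hash_coeff_gen<64>(coeff, seed);
 *   vector_inn_prdt_sum_red<64>(&digest, m, coeff);
 */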

/* packing in a Galois field (computes the sum of v[i] * X^i for v of size 128) */
class GaloisFieldPacking {
	public:
		block base[128];

		GaloisFieldPacking() {
			packing_base_gen();
		}

		~GaloisFieldPacking() {}

		void packing_base_gen() {
			// base[i] = X^i: this loop fills bits 0..63 (the low 64-bit word)
			uint64_t a = 0, b = 1;
			for(int i = 0; i < 64; i+=4) {
				base[i] = _mm_set_epi64x(a, b);
				base[i+1] = _mm_set_epi64x(a, b<<1);
				base[i+2] = _mm_set_epi64x(a, b<<2);
				base[i+3] = _mm_set_epi64x(a, b<<3);
				b <<= 4;
			}
			// continue with bits 64..127 (the high 64-bit word)
			a = 1, b = 0;
			for(int i = 64; i < 128; i+=4) {
				base[i] = _mm_set_epi64x(a, b);
				base[i+1] = _mm_set_epi64x(a<<1, b);
				base[i+2] = _mm_set_epi64x(a<<2, b);
				base[i+3] = _mm_set_epi64x(a<<3, b);
				a <<= 4;
			}
		}

		void packing(block *res, block *data) {
			vector_inn_prdt_sum_red(res, data, base, 128);
		}
};
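
/* Usage sketch (illustrative): base[i] encodes X^i, so packing computes
 * sum_i data[i] * X^i, collapsing a length-128 vector of field elements
 * into a single block:
 *
 *   GaloisFieldPacking gfp;
 *   block data[128], packed;
 *   // ... fill data ...
 *   gfp.packing(&packed, data);
 */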

/* XOR of all elements in a vector */
inline void vector_self_xor(block *sum, block *data, int sz) {
	block res[4];
	res[0] = zero_block;
	res[1] = zero_block;
	res[2] = zero_block;
	res[3] = zero_block;
	for(int i = 0; i < (sz/4)*4; i+=4) {
		res[0] = data[i] ^ res[0];
		res[1] = data[i+1] ^ res[1];
		res[2] = data[i+2] ^ res[2];
		res[3] = data[i+3] ^ res[3];
	}
	for(int i = (sz/4)*4, j = 0; i < sz; ++i, ++j)
		res[j] = data[i] ^ res[j];
	res[0] = res[0] ^ res[1];
	res[2] = res[2] ^ res[3];
	*sum = res[0] ^ res[2];
}
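
/* The four independent accumulators above break the XOR dependency chain,
 * letting the compiler and CPU overlap loads and XORs across iterations;
 * the tail loop folds in any leftover elements when sz is not a multiple
 * of 4. */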

/* XOR of all elements in a vector */
template<int N>
inline void vector_self_xor(block *sum, block *data) {
	vector_self_xor(sum, data, N);
}
}
#endif
